{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 113,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import numpy as np\n",
    "import timeit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0.],\n",
       "        [0., 0., 0.]])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.empty(2, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.7508, 0.0285, 0.7523],\n",
       "        [0.1396, 0.7432, 0.0843]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.rand(2, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0, 0, 0],\n",
       "        [0, 0, 0]])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.zeros(2, 3, dtype=torch.long)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.arange(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "array = [[1.0, 3.8, 2.1], [8.6, 4.0, 2.4]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.0000, 3.8000, 2.1000],\n",
       "        [8.6000, 4.0000, 2.4000]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.tensor(array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "array = np.array([[1.0, 3.8, 2.1], [8.6, 4.0, 2.4]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.0000, 3.8000, 2.1000],\n",
       "        [8.6000, 4.0000, 2.4000]], dtype=torch.float64)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.from_numpy(array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "ename": "AssertionError",
     "evalue": "Torch not compiled with CUDA enabled",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAssertionError\u001b[0m                            Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[15], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;66;03m#torch.rand(2, 3).cuda()\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrand\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcuda\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/cuda/__init__.py:293\u001b[0m, in \u001b[0;36m_lazy_init\u001b[0;34m()\u001b[0m\n\u001b[1;32m    288\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m    289\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCannot re-initialize CUDA in forked subprocess. To use CUDA with \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    290\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmultiprocessing, you must use the \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mspawn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m start method\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    291\u001b[0m     )\n\u001b[1;32m    292\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(torch\u001b[38;5;241m.\u001b[39m_C, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_cuda_getDeviceCount\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 293\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAssertionError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTorch not compiled with CUDA enabled\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m    294\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _cudart \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    295\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAssertionError\u001b[39;00m(\n\u001b[1;32m    296\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlibcudart functions unavailable. It looks like you have a broken build?\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    297\u001b[0m     )\n",
      "\u001b[0;31mAssertionError\u001b[0m: Torch not compiled with CUDA enabled"
     ]
    }
   ],
   "source": [
    "#torch.rand(2, 3).cuda()\n",
    "torch.rand(2, 3, device=\"cuda\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.tensor([1, 2, 3], dtype=torch.double)\n",
    "y = torch.tensor([4, 5, 6], dtype=torch.double)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([5., 7., 9.], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "print(x + y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([-3., -3., -3.], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "print(x - y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 4., 10., 18.], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "print(x * y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.2500, 0.4000, 0.5000], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "print(x / y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(32., dtype=torch.float64)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.dot(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.8415, 0.9093, 0.1411], dtype=torch.float64)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.sin()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 2.7183,  7.3891, 20.0855], dtype=torch.float64)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.exp()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.double)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(3.5000, dtype=torch.float64)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([2.5000, 3.5000, 4.5000], dtype=torch.float64)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.mean(dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([2., 5.], dtype=torch.float64)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.mean(dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[2.5000, 3.5000, 4.5000]], dtype=torch.float64)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.mean(dim=0, keepdim=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[2.],\n",
       "        [5.]], dtype=torch.float64)"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.mean(dim=1, keepdim=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.tensor([[1, 2, 3], [ 4,  5,  6]], dtype=torch.double)\n",
    "y = torch.tensor([[7, 8, 9], [10, 11, 12]], dtype=torch.double)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.,  2.,  3.],\n",
       "        [ 4.,  5.,  6.],\n",
       "        [ 7.,  8.,  9.],\n",
       "        [10., 11., 12.]], dtype=torch.float64)"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.cat((x, y), dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.,  2.,  3.,  7.,  8.,  9.],\n",
       "        [ 4.,  5.,  6., 10., 11., 12.]], dtype=torch.float64)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.cat((x, y), dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.tensor([2.])\n",
    "y = torch.tensor([3.])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([5.])\n"
     ]
    }
   ],
   "source": [
    "z = (x + y) * (y - 2)\n",
    "print(z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "80.6862203619985\n"
     ]
    }
   ],
   "source": [
    "M = torch.rand(1000, 1000)\n",
    "print(timeit.timeit(lambda: M.mm(M).mm(M), number=5000))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "False\n"
     ]
    }
   ],
   "source": [
    "print(torch.cuda.is_available())\n",
    "#N = torch.rand(1000, 1000).cuda()\n",
    "#print(timeit.timeit(lambda: N.mm(N).mm(N), number=5000))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "# z = (x + y) * (y - 2)\n",
    "x = torch.tensor([2.], requires_grad=True)\n",
    "y = torch.tensor([3.], requires_grad=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([5.], grad_fn=<MulBackward0>)\n"
     ]
    }
   ],
   "source": [
    "z = (x + y) * (y - 2)\n",
    "print(z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1.]) tensor([6.])\n"
     ]
    }
   ],
   "source": [
    "z.backward()\n",
    "print(x.grad, y.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1, 2, 3, 4, 5, 6]) torch.Size([6]) True\n"
     ]
    }
   ],
   "source": [
    "x = torch.tensor([1, 2, 3, 4, 5, 6])\n",
    "print(x, x.shape, x.is_contiguous())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 2, 3],\n",
       "        [4, 5, 6]])"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.view(2, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 2],\n",
       "        [3, 4],\n",
       "        [5, 6]])"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.view(3, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 2, 3],\n",
       "        [4, 5, 6]])"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.view(-1, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 2, 3],\n",
       "        [4, 5, 6]])"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.tensor([[1, 2, 3], [4, 5, 6]])\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 4],\n",
       "        [2, 5],\n",
       "        [3, 6]])"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.transpose(0, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[1, 2, 3],\n",
      "         [4, 5, 6]]]) torch.Size([1, 2, 3])\n"
     ]
    }
   ],
   "source": [
    "x = torch.tensor([[[1, 2, 3], [4, 5, 6]]])\n",
    "print(x, x.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[1, 4]],\n",
      "\n",
      "        [[2, 5]],\n",
      "\n",
      "        [[3, 6]]]) torch.Size([3, 1, 2])\n"
     ]
    }
   ],
   "source": [
    "x = x.permute(2, 0, 1)\n",
    "print(x, x.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.arange(1, 4).view(3, 1)\n",
    "y = torch.arange(4, 6).view(1, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1],\n",
      "        [2],\n",
      "        [3]]) tensor([[4, 5]])\n"
     ]
    }
   ],
   "source": [
    "print(x, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[5, 6],\n",
      "        [6, 7],\n",
      "        [7, 8]])\n"
     ]
    }
   ],
   "source": [
    "print(x + y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0,  1,  2,  3],\n",
       "        [ 4,  5,  6,  7],\n",
       "        [ 8,  9, 10, 11]])"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.arange(12).view(3, 4)\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(7)"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x[1, 3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([4, 5, 6, 7])"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 4,  5,  6,  7],\n",
       "        [ 8,  9, 10, 11]])"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x[1:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 2,  6, 10])"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x[:, 2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 2,  3],\n",
       "        [ 6,  7],\n",
       "        [10, 11]])"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x[:, 2:4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[  0,   1, 100, 100],\n",
       "        [  4,   5, 100, 100],\n",
       "        [  8,   9, 100, 100]])"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x[:, 2:4] = 100\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4])"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = torch.tensor([1, 2, 3, 4])\n",
    "a.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 2, 3, 4]])"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.unsqueeze(a, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1],\n",
       "        [2],\n",
       "        [3],\n",
       "        [4]])"
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b = a.unsqueeze(dim=1)\n",
    "b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1, 2, 3, 4])"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b.squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import IterableDataset, DataLoader\n",
    "from torch.utils.data import get_worker_info\n",
    "import math\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyIterableDataset(IterableDataset):\n",
    "\n",
    "    def __init__(self, start, end):\n",
    "        super(MyIterableDataset).__init__()\n",
    "        assert end > start\n",
    "        self.start = start\n",
    "        self.end = end\n",
    "\n",
    "    def __iter__(self):\n",
    "        return iter(range(self.start, self.end))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[tensor([3]), tensor([4]), tensor([5]), tensor([6])]\n"
     ]
    }
   ],
   "source": [
    "ds = MyIterableDataset(start=3, end=7) # [3, 4, 5, 6]\n",
    "# Single-process loading\n",
    "print(list(DataLoader(ds, num_workers=0)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "def worker_init_fn(worker_id):\n",
    "    worker_info = get_worker_info()\n",
    "    dataset = worker_info.dataset  # the dataset copy in this worker process\n",
    "    overall_start = dataset.start\n",
    "    overall_end = dataset.end\n",
    "    # configure the dataset to only process the split workload\n",
    "    per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))\n",
    "    worker_id = worker_info.id\n",
    "    dataset.start = overall_start + worker_id * per_worker\n",
    "    dataset.end = min(dataset.start + per_worker, overall_end)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Traceback (most recent call last):\n",
      "  File \"<string>\", line 1, in <module>\n",
      "  File \"/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/spawn.py\", line 116, in spawn_main\n",
      "Traceback (most recent call last):\n",
      "  File \"<string>\", line 1, in <module>\n",
      "  File \"/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/spawn.py\", line 116, in spawn_main\n",
      "    exitcode = _main(fd, parent_sentinel)\n",
      "  File \"/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/spawn.py\", line 126, in _main\n",
      "    exitcode = _main(fd, parent_sentinel)\n",
      "      File \"/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/spawn.py\", line 126, in _main\n",
      "self = reduction.pickle.load(from_parent)\n",
      "AttributeError: Can't get attribute 'MyIterableDataset' on <module '__main__' (built-in)>\n",
      "    self = reduction.pickle.load(from_parent)\n",
      "AttributeError: Can't get attribute 'MyIterableDataset' on <module '__main__' (built-in)>\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "DataLoader worker (pid(s) 85295, 85296) exited unexpectedly",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/utils/data/dataloader.py:1133\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._try_get_data\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m   1132\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1133\u001b[0m     data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_queue\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1134\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[38;5;28;01mTrue\u001b[39;00m, data)\n",
      "File \u001b[0;32m/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/queues.py:113\u001b[0m, in \u001b[0;36mQueue.get\u001b[0;34m(self, block, timeout)\u001b[0m\n\u001b[1;32m    112\u001b[0m timeout \u001b[38;5;241m=\u001b[39m deadline \u001b[38;5;241m-\u001b[39m time\u001b[38;5;241m.\u001b[39mmonotonic()\n\u001b[0;32m--> 113\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_poll\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m    114\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m Empty\n",
      "File \u001b[0;32m/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/connection.py:262\u001b[0m, in \u001b[0;36m_ConnectionBase.poll\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m    261\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_readable()\n\u001b[0;32m--> 262\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_poll\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/connection.py:429\u001b[0m, in \u001b[0;36mConnection._poll\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m    428\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_poll\u001b[39m(\u001b[38;5;28mself\u001b[39m, timeout):\n\u001b[0;32m--> 429\u001b[0m     r \u001b[38;5;241m=\u001b[39m \u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    430\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mbool\u001b[39m(r)\n",
      "File \u001b[0;32m/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/connection.py:936\u001b[0m, in \u001b[0;36mwait\u001b[0;34m(object_list, timeout)\u001b[0m\n\u001b[1;32m    935\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[0;32m--> 936\u001b[0m     ready \u001b[38;5;241m=\u001b[39m \u001b[43mselector\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mselect\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    937\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m ready:\n",
      "File \u001b[0;32m/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/selectors.py:416\u001b[0m, in \u001b[0;36m_PollLikeSelector.select\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m    415\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 416\u001b[0m     fd_event_list \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_selector\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpoll\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    417\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mInterruptedError\u001b[39;00m:\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/utils/data/_utils/signal_handling.py:66\u001b[0m, in \u001b[0;36m_set_SIGCHLD_handler.<locals>.handler\u001b[0;34m(signum, frame)\u001b[0m\n\u001b[1;32m     63\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mhandler\u001b[39m(signum, frame):\n\u001b[1;32m     64\u001b[0m     \u001b[38;5;66;03m# This following call uses `waitid` with WNOHANG from C side. Therefore,\u001b[39;00m\n\u001b[1;32m     65\u001b[0m     \u001b[38;5;66;03m# Python can still get and update the process status successfully.\u001b[39;00m\n\u001b[0;32m---> 66\u001b[0m     \u001b[43m_error_if_any_worker_fails\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     67\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m previous_handler \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
      "\u001b[0;31mRuntimeError\u001b[0m: DataLoader worker (pid 85295) exited unexpectedly with exit code 1. Details are lost due to multiprocessing. Rerunning with num_workers=0 may give better error trace.",
      "\nThe above exception was the direct cause of the following exception:\n",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[76], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;66;03m# Directly doing multi-process loading\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;28;43mlist\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mDataLoader\u001b[49m\u001b[43m(\u001b[49m\u001b[43mds\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_workers\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mworker_init_fn\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mworker_init_fn\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m)\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/utils/data/dataloader.py:631\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    628\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    629\u001b[0m     \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m    630\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset()  \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 631\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    632\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m    633\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m    634\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m    635\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/utils/data/dataloader.py:1329\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1326\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_process_data(data)\n\u001b[1;32m   1328\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_shutdown \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tasks_outstanding \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m-> 1329\u001b[0m idx, data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1330\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tasks_outstanding \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m   1331\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable:\n\u001b[1;32m   1332\u001b[0m     \u001b[38;5;66;03m# Check for _IterableDatasetStopIteration\u001b[39;00m\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/utils/data/dataloader.py:1295\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._get_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1291\u001b[0m     \u001b[38;5;66;03m# In this case, `self._data_queue` is a `queue.Queue`,. But we don't\u001b[39;00m\n\u001b[1;32m   1292\u001b[0m     \u001b[38;5;66;03m# need to call `.task_done()` because we don't use `.join()`.\u001b[39;00m\n\u001b[1;32m   1293\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   1294\u001b[0m     \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[0;32m-> 1295\u001b[0m         success, data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_try_get_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1296\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m success:\n\u001b[1;32m   1297\u001b[0m             \u001b[38;5;28;01mreturn\u001b[39;00m data\n",
      "File \u001b[0;32m~/Library/Python/3.9/lib/python/site-packages/torch/utils/data/dataloader.py:1146\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._try_get_data\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m   1144\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(failed_workers) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m   1145\u001b[0m     pids_str \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;28mstr\u001b[39m(w\u001b[38;5;241m.\u001b[39mpid) \u001b[38;5;28;01mfor\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m failed_workers)\n\u001b[0;32m-> 1146\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mDataLoader worker (pid(s) \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpids_str\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m) exited unexpectedly\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01me\u001b[39;00m\n\u001b[1;32m   1147\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(e, queue\u001b[38;5;241m.\u001b[39mEmpty):\n\u001b[1;32m   1148\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m)\n",
      "\u001b[0;31mRuntimeError\u001b[0m: DataLoader worker (pid(s) 85295, 85296) exited unexpectedly"
     ]
    }
   ],
   "source": [
    "# Directly doing multi-process loading\n",
    "print(list(DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "from torchvision.io import read_image\n",
    "from torch.utils.data import Dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 可以看到，我们实现了 __init__()、__len__() 和 __getitem__() 三个函数，其中：\n",
    "# __init__() 初始化数据集参数，这里设置了图像的存储目录、标签（通过读取标签 csv 文件）以及样本和标签的数据转换函数；\n",
    "# __len__() 返回数据集中样本的个数；\n",
    "# __getitem__() 映射型数据集的核心，根据给定的索引 idx 返回样本。这里会根据索引从目录下通过 read_image 读取图片和从 csv 文件中读取图片标签，并且返回处理后的图像和标签。\n",
    "class CustomImageDataset(Dataset):\n",
    "    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None) -> None:\n",
    "        self.img_labels = pd.read_csv(annotations_file)\n",
    "        self.img_dir = img_dir\n",
    "        self.transform = transform\n",
    "        self.target_transform = target_transform\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.img_labels)\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])\n",
    "        image = read_image(img_path)\n",
    "        label = self.img_labels.iloc[idx, 1]\n",
    "        if self.transform:\n",
    "            image = self.transform(image)\n",
    "        if self.target_transform:\n",
    "            label = self.target_transform(label)\n",
    "        return image, label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.transforms import ToTensor\n",
    "from torchvision import datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz\n",
      "Using downloaded and verified file: /Users/mac/Downloads/datasets/FashionMNIST/raw/train-images-idx3-ubyte.gz\n",
      "Extracting /Users/mac/Downloads/datasets/FashionMNIST/raw/train-images-idx3-ubyte.gz to /Users/mac/Downloads/datasets/FashionMNIST/raw\n",
      "\n",
      "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz\n",
      "Using downloaded and verified file: /Users/mac/Downloads/datasets/FashionMNIST/raw/train-labels-idx1-ubyte.gz\n",
      "Extracting /Users/mac/Downloads/datasets/FashionMNIST/raw/train-labels-idx1-ubyte.gz to /Users/mac/Downloads/datasets/FashionMNIST/raw\n",
      "\n",
      "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz\n",
      "Using downloaded and verified file: /Users/mac/Downloads/datasets/FashionMNIST/raw/t10k-images-idx3-ubyte.gz\n",
      "Extracting /Users/mac/Downloads/datasets/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /Users/mac/Downloads/datasets/FashionMNIST/raw\n",
      "\n",
      "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz\n",
      "Using downloaded and verified file: /Users/mac/Downloads/datasets/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz\n",
      "Extracting /Users/mac/Downloads/datasets/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to /Users/mac/Downloads/datasets/FashionMNIST/raw\n",
      "\n"
     ]
    }
   ],
   "source": [
    "training_data = datasets.FashionMNIST(\n",
    "    root=\"/Users/mac/Downloads/datasets\",\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=ToTensor()\n",
    ")\n",
    "test_data = datasets.FashionMNIST(\n",
    "    root=\"/Users/mac/Downloads/datasets\",\n",
    "    train=False,\n",
    "    download=True,\n",
    "    transform=ToTensor()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([60000, 28, 28])\n",
      "torch.Size([10000, 28, 28])\n"
     ]
    }
   ],
   "source": [
    "len(training_data),len(test_data)\n",
    "print(training_data.data.shape)\n",
    "print(test_data.data.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)\n",
    "test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([60000, 28, 28])"
      ]
     },
     "execution_count": 115,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataloader.dataset.data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Feature batch shape: torch.Size([64, 1, 28, 28])\n",
      "Feature batch shape: torch.Size([64, 1, 28, 28])\n",
      "Labels batch shape: torch.Size([64])\n"
     ]
    }
   ],
   "source": [
    "train_features, train_labels = next(iter(train_dataloader))\n",
    "print(f\"Feature batch shape: {train_features.shape}\")\n",
    "print(f\"Feature batch shape: {train_features.size()}\")\n",
    "print(f\"Labels batch shape: {train_labels.size()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 28, 28])\n",
      "torch.Size([28, 28])\n",
      "Label: 9\n"
     ]
    }
   ],
   "source": [
    "print(train_features[0].shape)\n",
    "img = train_features[0].squeeze()\n",
    "label = train_labels[0]\n",
    "print(img.shape)\n",
    "print(f\"Label: {label}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import SequentialSampler, RandomSampler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_sampler = RandomSampler(training_data)\n",
    "test_sampler = SequentialSampler(test_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataloader = DataLoader(training_data, batch_size=64, sampler=train_sampler)\n",
    "test_dataloader = DataLoader(test_data, batch_size=64, sampler=test_sampler)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Feature batch shape: torch.Size([64, 1, 28, 28])\n",
      "Labels batch shape: torch.Size([64])\n",
      "Feature batch shape: torch.Size([64, 1, 28, 28])\n",
      "Labels batch shape: torch.Size([64])\n"
     ]
    }
   ],
   "source": [
    "train_features, train_labels = next(iter(train_dataloader))\n",
    "print(f\"Feature batch shape: {train_features.size()}\")\n",
    "print(f\"Labels batch shape: {train_labels.size()}\")\n",
    "test_features, test_labels = next(iter(test_dataloader))\n",
    "print(f\"Feature batch shape: {test_features.size()}\")\n",
    "print(f\"Labels batch shape: {test_labels.size()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.tensor([[1, 2, 3], [ 4,  5,  6]], dtype=torch.double)\n",
    "y = torch.tensor([[7, 8, 9], [10, 11, 12]], dtype=torch.double)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[1., 2., 3.],\n",
       "         [4., 5., 6.]], dtype=torch.float64),\n",
       " tensor([[ 7.,  8.,  9.],\n",
       "         [10., 11., 12.]], dtype=torch.float64))"
      ]
     },
     "execution_count": 119,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1., 2., 3., 4., 5., 6.], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "flatten = nn.Flatten(start_dim=0, end_dim=-1)\n",
    "print(flatten(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 145,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using cpu device\n",
      "NeuralNetwork(\n",
      "  (flatten): Flatten(start_dim=1, end_dim=-1)\n",
      "  (linear_relu_stack): Sequential(\n",
      "    (0): Linear(in_features=784, out_features=512, bias=True)\n",
      "    (1): ReLU()\n",
      "    (2): Linear(in_features=512, out_features=256, bias=True)\n",
      "    (3): ReLU()\n",
      "    (4): Linear(in_features=256, out_features=10, bias=True)\n",
      "    (5): Dropout(p=0.2, inplace=False)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "print(f'Using {device} device')\n",
    "\n",
    "class NeuralNetwork(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(NeuralNetwork, self).__init__()\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.linear_relu_stack = nn.Sequential(\n",
    "            nn.Linear(28*28, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512, 256),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(256, 10),\n",
    "            nn.Dropout(p=0.2)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.flatten(x)\n",
    "        #print(\"NeuralNetwork forward \\n\", x.shape)\n",
    "        logits = self.linear_relu_stack(x)\n",
    "        return logits\n",
    "\n",
    "model = NeuralNetwork().to(device)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 137,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([28, 28])\n",
      "NeuralNetwork forward \n",
      " torch.Size([4, 784])\n",
      "torch.Size([4, 10]) tensor([[0.1004, 0.0959, 0.0997, 0.1062, 0.0913, 0.0891, 0.1031, 0.1014, 0.0977,\n",
      "         0.1151],\n",
      "        [0.1003, 0.0954, 0.0993, 0.1056, 0.0994, 0.0993, 0.0916, 0.1023, 0.0993,\n",
      "         0.1075],\n",
      "        [0.1021, 0.0967, 0.1027, 0.1037, 0.0939, 0.0890, 0.1002, 0.1044, 0.0976,\n",
      "         0.1097],\n",
      "        [0.1026, 0.0936, 0.1069, 0.1059, 0.0955, 0.0820, 0.1003, 0.1008, 0.1008,\n",
      "         0.1116]], grad_fn=<SoftmaxBackward0>)\n",
      "Predicted class: tensor([9, 9, 9, 9])\n"
     ]
    }
   ],
   "source": [
    "X = torch.rand(4, 28, 28, device=device)\n",
    "print(X[0].shape)\n",
    "logits = model(X)\n",
    "pred_probab = nn.Softmax(dim=1)(logits)\n",
    "print(pred_probab.size(), pred_probab)\n",
    "y_pred = pred_probab.argmax(-1)\n",
    "print(f\"Predicted class: {y_pred}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "metadata": {},
   "outputs": [],
   "source": [
    "learning_rate = 1e-3\n",
    "batch_size = 64\n",
    "epochs = 3\n",
    "train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)\n",
    "test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 147,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_loop(dataloader, model, loss_fn, optimizer):\n",
    "    size = len(dataloader.dataset)\n",
    "    model.train()\n",
    "    for batch, (X, y) in enumerate(dataloader, start=1):\n",
    "        X, y = X.to(device), y.to(device)\n",
    "        # Compute prediction and loss\n",
    "        pred = model(X)\n",
    "        loss = loss_fn(pred, y)\n",
    "        # Backpropagation\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if batch % 100 == 0:\n",
    "            print(\"X,y\", X.shape, y.shape)\n",
    "            loss, current = loss.item(), batch * len(X)\n",
    "            print(f\"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 142,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_loop(dataloader, model, loss_fn):\n",
    "    size = len(dataloader.dataset)\n",
    "    num_batches = len(dataloader)\n",
    "    test_loss, correct = 0, 0\n",
    "\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        for X, y in dataloader:\n",
    "            X, y = X.to(device), y.to(device)\n",
    "            pred = model(X)\n",
    "            test_loss += loss_fn(pred, y).item()\n",
    "            correct += (pred.argmax(dim=-1) == y).type(torch.float).sum().item()\n",
    "\n",
    "    test_loss /= num_batches\n",
    "    correct /= size\n",
    "    print(f\"Test Error: \\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1\n",
      "-------------------------------\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.551137  [ 6400/60000]\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.446274  [12800/60000]\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.463588  [19200/60000]\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.625885  [25600/60000]\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.513620  [32000/60000]\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.561733  [38400/60000]\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.578524  [44800/60000]\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.512603  [51200/60000]\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.528623  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 87.0%, Avg loss: 0.364369 \n",
      "\n",
      "Epoch 2\n",
      "-------------------------------\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.440266  [ 6400/60000]\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.298903  [12800/60000]\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.420888  [19200/60000]\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.532288  [25600/60000]\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.371268  [32000/60000]\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.399205  [38400/60000]\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.640780  [44800/60000]\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.470012  [51200/60000]\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.662610  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 87.5%, Avg loss: 0.348740 \n",
      "\n",
      "Epoch 3\n",
      "-------------------------------\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.469030  [ 6400/60000]\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.374180  [12800/60000]\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.620024  [19200/60000]\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.687579  [25600/60000]\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.797273  [32000/60000]\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.429272  [38400/60000]\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.344569  [44800/60000]\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.532963  [51200/60000]\n",
      "X,y torch.Size([64, 1, 28, 28]) torch.Size([64])\n",
      "loss: 0.491500  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 87.8%, Avg loss: 0.355076 \n",
      "\n",
      "Done!\n"
     ]
    }
   ],
   "source": [
    "loss_fn = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
    "\n",
    "for t in range(epochs):\n",
    "    print(f\"Epoch {t+1}\\n-------------------------------\")\n",
    "    train_loop(train_dataloader, model, loss_fn, optimizer)\n",
    "    test_loop(test_dataloader, model, loss_fn)\n",
    "print(\"Done!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
